# -*- coding: utf-8 -*-

import tensorflow as tf


def load_image(path, size=224):
    """Load a JPEG file and return it as a one-image RGB batch.

    This function must be called inside a ``tf.Session`` context (e.g. a
    ``with tf.Session():`` block). The graph it builds is evaluated with
    ``Tensor.eval()``, which uses the default session.

    Args:
        path: Path of the JPEG file, read through ``tf.gfile``.
        size: Side length in pixels of the square output image. Defaults to
            224, which was previously hard-coded.

    Returns:
        A NumPy array of shape ``(1, size, size, 3)``. Grayscale JPEGs are
        expanded to 3 identical RGB channels.
    """
    # Read the encoded bytes and close the handle deterministically. The
    # original left the file object open.
    with tf.gfile.GFile(path, 'rb') as f:
        image_raw_data_jpg = f.read()

    # channels=3 makes the decoder emit RGB even for grayscale JPEGs. This
    # replaces the old post-hoc grayscale_to_rgb pass, which evaluated (and
    # therefore decoded) the image twice. Decoded layout: [height, width, channel].
    img_data_jpg = tf.image.decode_jpeg(image_raw_data_jpg, channels=3)
    resized_image = tf.image.resize_images(
        img_data_jpg, [size, size],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # NOTE(review): each call adds new decode/resize ops to the default graph.
    # When this is called in a loop, consider building the ops once with a
    # placeholder and feeding the bytes instead.
    img = resized_image.eval()

    # Add the leading batch dimension.
    return img.reshape(-1, size, size, 3)

def main():
    """Smoke test: load one COCO training image and print its batch shape."""
    sample_path = '../data/img/train2014/COCO_train2014_000000004968.jpg'
    # load_image evaluates tensors through the default session, so the call
    # has to happen while this session is active.
    with tf.Session():
        batch = load_image(sample_path)
        print(batch.shape)

# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


